{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "89350a42-2174-4a8f-b183-b77669e72d5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.nn.utils.rnn import pack_sequence\n",
    "import random\n",
    "    \n",
    "class MELD_loader(Dataset):\n",
    "    \"\"\"MELD utterances, each paired with the dialog context that precedes it.\n",
    "\n",
    "    File layout, as consumed by the parser: the first two lines are skipped,\n",
    "    dialogs are separated by blank lines, and every other line is\n",
    "    ``speaker<TAB>utterance<TAB>emotion<TAB>sentiment``.\n",
    "\n",
    "    Every utterance yields one sample\n",
    "    ``[speaker_ids, utterances, emotion, sentiment]``: ``utterances`` is the\n",
    "    dialog so far (target utterance last), ``speaker_ids`` gives each\n",
    "    utterance's speaker as an index in order of first appearance within the\n",
    "    dialog, and the two labels belong to the last utterance.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, txt_file, dataclass, seed=None):\n",
    "        \"\"\"Parse ``txt_file`` into per-utterance samples and shuffle them.\n",
    "\n",
    "        Args:\n",
    "            txt_file: Path to the tab-separated text file.\n",
    "            dataclass: ``'emotion'`` exposes the emotion labels as\n",
    "                ``labelList``; any other value exposes the sentiment labels.\n",
    "            seed: Optional seed for the shuffle. ``None`` (the default) keeps\n",
    "                using the global ``random`` state; pass an int to get a\n",
    "                reproducible sample order.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If a non-blank line does not have exactly four\n",
    "                tab-separated fields.\n",
    "            KeyError: If a line carries an emotion missing from the label map.\n",
    "        \"\"\"\n",
    "        self.dialogs = []\n",
    "        self.speakerNum = []  # number of distinct speakers in each parsed dialog\n",
    "        self.emoSet = set()\n",
    "        self.sentiSet = set()\n",
    "\n",
    "        # Raw emotion name in the file -> label exposed by this dataset.\n",
    "        # 'anger', 'disgust', 'fear', 'joy', 'neutral', 'sadness', 'surprise'\n",
    "        emodict = {'anger': \"anger\", 'disgust': \"disgust\", 'fear': \"fear\", 'joy': \"joy\", 'neutral': \"neutral\", 'sadness': \"sad\", 'surprise': 'surprise'}\n",
    "        # Emotions grouped by sentiment, spelled with the exposed labels\n",
    "        # ('sad', not the raw 'sadness') so every entry can be found in emoList.\n",
    "        self.sentidict = {'positive': [\"joy\"], 'negative': [\"anger\", \"disgust\", \"fear\", \"sad\"], 'neutral': [\"neutral\", \"surprise\"]}\n",
    "\n",
    "        # The context manager closes the file even if reading raises.\n",
    "        with open(txt_file, 'r') as f:\n",
    "            dataset = f.readlines()\n",
    "\n",
    "        temp_speakerList = []\n",
    "        context = []\n",
    "        context_speaker = []\n",
    "        for i, data in enumerate(dataset):\n",
    "            if i < 2:\n",
    "                continue\n",
    "            if not data.strip():\n",
    "                # Blank or whitespace-only line: the current dialog is over.\n",
    "                # Leading or repeated blank lines have no open dialog, so they\n",
    "                # must neither reach the field split nor record a zero count.\n",
    "                if temp_speakerList:\n",
    "                    self.speakerNum.append(len(temp_speakerList))\n",
    "                temp_speakerList = []\n",
    "                context = []\n",
    "                context_speaker = []\n",
    "                continue\n",
    "            speaker, utt, emo, senti = data.strip().split('\\t')\n",
    "            context.append(utt)\n",
    "            if speaker not in temp_speakerList:\n",
    "                temp_speakerList.append(speaker)\n",
    "            context_speaker.append(temp_speakerList.index(speaker))\n",
    "\n",
    "            # Store copies: the running lists keep growing with the dialog.\n",
    "            label = emodict[emo]\n",
    "            self.dialogs.append([context_speaker[:], context[:], label, senti])\n",
    "            self.emoSet.add(label)\n",
    "            self.sentiSet.add(senti)\n",
    "\n",
    "        # Count the last dialog when the file does not end with a blank line.\n",
    "        if temp_speakerList:\n",
    "            self.speakerNum.append(len(temp_speakerList))\n",
    "\n",
    "        self.emoList = sorted(self.emoSet)\n",
    "        self.sentiList = sorted(self.sentiSet)\n",
    "        if dataclass == 'emotion':\n",
    "            self.labelList = self.emoList\n",
    "        else:\n",
    "            self.labelList = self.sentiList\n",
    "\n",
    "        if seed is None:\n",
    "            random.shuffle(self.dialogs)\n",
    "        else:\n",
    "            # A private generator leaves the global random state untouched.\n",
    "            random.Random(seed).shuffle(self.dialogs)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of samples (one per parsed utterance).\"\"\"\n",
    "        return len(self.dialogs)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return ``(sample, labelList, sentidict)`` for position ``idx``.\n",
    "\n",
    "        ``labelList`` and ``sentidict`` are the same shared objects for every\n",
    "        index; only ``sample`` varies.\n",
    "        \"\"\"\n",
    "        return self.dialogs[idx], self.labelList, self.sentidict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b7c7a5ab-ff60-42ce-964a-54068d4fe59f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the dataset from MELD_train.txt; 'emotion' selects the emotion labels as `labelList`.\n",
    "dataset = MELD_loader('./multi/MELD_train.txt', 'emotion')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4d010288-c7e0-4340-a6de-e35150af4b58",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([[0, 1, 2, 0, 2, 2, 2, 2, 2, 0, 2, 1, 0, 1],\n",
       "  ['....and 12, 22, 18, four...  What?',\n",
       "   'I spelled out boobies.',\n",
       "   'Ross, but me down for another box of the mint treasures, okay. Where, where are the mint treasures?',\n",
       "   'Ah, we’re out. I sold them all.',\n",
       "   'What?',\n",
       "   'No.',\n",
       "   'No, just, just, just a couple more boxes.',\n",
       "   'It-it-it’s no big deal, all right, I’m-I’m cool.',\n",
       "   'You gotta help me out with a couple more boxes!',\n",
       "   'Mon, look at yourself. You have cookie on your neck.',\n",
       "   'Oh God!',\n",
       "   'So, how many have you sold so far?',\n",
       "   'Check this out. Five hundred and seventeen boxes!',\n",
       "   'Oh my God, how did you do that?'],\n",
       "  'surprise',\n",
       "  'positive'],\n",
       " ['anger', 'disgust', 'fear', 'joy', 'neutral', 'sad', 'surprise'],\n",
       " {'positive': ['joy'],\n",
       "  'negative': ['anger', 'disgust', 'fear', 'sadness'],\n",
       "  'neutral': ['neutral', 'surprise']})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect one item: ([speaker ids, utterances, emotion, sentiment], labelList, sentidict).\n",
    "# The constructor shuffles the samples, so the item at index 0 can differ between runs.\n",
    "dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2f1446f-b647-4e8c-b9ac-335ea1ec5de0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
